import numpy as np 
from scipy.integrate import odeint
from scipy.optimize import fmin
import matplotlib.pyplot as plt

def H(A, theta):
    """Count below-threshold neighbours of every cell in a 1-D row.

    For cell ``i`` the result is the number of its two neighbours ``i - 1``
    and ``i + 1`` whose value in ``A`` is strictly below ``theta``.  A
    neighbour that would lie past either end of the row counts as 1, so each
    end cell always receives one "virtual" below-threshold neighbour.

    Parameters
    ----------
    A : array_like, 1-D
        Per-cell values.
    theta : float
        Threshold; a cell with ``A < theta`` contributes 1 to each neighbour.

    Returns
    -------
    numpy.ndarray
        Neighbour counts (0, 1 or 2), same length (and dtype) as ``A``.
    """
    A = np.asarray(A)
    below = np.zeros_like(A)
    below[A < theta] = 1
    # Pad both ends with a virtual below-threshold neighbour, then add the
    # left- and right-shifted copies.  Both boundaries are treated the same
    # way: the last cell counts its real neighbour A[-2] (not itself), and a
    # single-cell row yields 2 instead of raising IndexError.
    padded = np.pad(below, 1, constant_values=1)
    return padded[:-2] + padded[2:]

def dAdI(X, t, r, mu, rho, theta, delta):
    """Right-hand side of the coupled (A, I) system in odeint's ``f(y, t, ...)`` form.

    ``X`` holds the A values in its first half and the I values in its
    second half.  The rates are::

        dA/dt = (r / (1 + I) - A) / rho
        dI/dt = (mu * H(A, theta) - I) / delta

    ``t`` is not used.  Unlike ``dA``, no cell is frozen once it passes
    ``theta``.

    Returns
    -------
    numpy.ndarray
        ``dA/dt`` followed by ``dI/dt``, concatenated into one 1-D array.
    """
    half = X.shape[0] // 2
    A, I = X[:half], X[half:]
    rate_A = (r / (1 + I) - A) / rho
    rate_I = (mu * H(A, theta) - I) / delta
    return np.concatenate((rate_A, rate_I))


def dA(X, t, r, mu, rho, theta):
    """Right-hand side of the reduced, A-only model in odeint's ``f(y, t, ...)`` form.

    I is not integrated here.  It is set directly to ``mu * H(A, theta)``,
    the value at which dI/dt in ``dAdI`` vanishes.  Cells with ``A > theta``
    have their rate forced to 0, so they stay put once past the threshold.
    ``t`` is not used.
    """
    A = X.copy()
    inhibition = mu * H(A, theta)
    rate = (r / (1 + inhibition) - A) / rho
    # freeze cells that are already above threshold
    rate[A > theta] = 0
    return rate


def xhi2_times(k, cpos, ctimes):
    """Sum of squared errors between ``get_times2`` predictions and ``ctimes``.

    Parameters
    ----------
    k : sequence
        Model parameters, forwarded unchanged to ``get_times2``.
    cpos : numpy.ndarray
        Cell positions, forwarded to ``get_times2``.
    ctimes : numpy.ndarray
        Observed times; the simulated times are reshaped to its shape.

    Returns
    -------
    float
        ``sum((simulated - ctimes) ** 2)``.  ``k`` and the score are printed on
        every call, which serves as a progress trace when this is used as an
        ``fmin`` objective.
    """
    times = get_times2(k, cpos).reshape(ctimes.shape)
    xhi2 = np.sum((times - ctimes) ** 2)
    print(k, xhi2)
    return xhi2


def xhi2_timesrl(p, cpos, ctimes):
    """Score a (rho, mu) guess against observed times, with r and delta fixed.

    ``p`` is expanded to ``k = [2.1, p[0], p[1], 1.0]``, and ``get_timesr`` is
    run 10 times; the per-cell mean is compared with ``ctimes``.  Only cells
    with ``cpos >= 0`` enter the score.  ``get_timesr`` draws random initial
    conditions, so repeated calls with the same ``p`` can return different
    values.

    Returns
    -------
    float
        Sum of squared differences over the retained cells.  ``k`` and the
        score are printed on each call.
    """
    k = [2.1, p[0], p[1], 1.0]
    runs = np.stack([get_timesr(k, cpos) for _ in range(10)])
    mean_times = np.mean(runs, 0).reshape(ctimes.shape)
    keep = cpos >= 0
    xhi2 = np.sum((mean_times[keep] - ctimes[keep]) ** 2)
    print(k, xhi2)
    return xhi2

def xhi2_timesr(k, cpos, ctimes):
    """Sum of squared errors between the mean of 10 ``get_timesr`` runs and ``ctimes``.

    ``k`` is forwarded unchanged to ``get_timesr``.  Every cell is scored.
    ``k`` and the score are printed on each call.
    """
    mean_times = np.mean(np.stack([get_timesr(k, cpos) for _ in range(10)]), 0)
    residual = mean_times.reshape(ctimes.shape) - ctimes
    xhi2 = np.sum(residual ** 2)
    print(k, xhi2)
    return xhi2


def get_times(k, cpos):
    """Simulate the full (A, I) model and return each cell's threshold-crossing time.

    Parameters
    ----------
    k : sequence
        ``(r, rho, mu, delta)``; any extra entries are ignored.
    cpos : numpy.ndarray, 1-D
        Cell positions.  Only its length and which entries equal 0 matter:
        those cells start at ``A = 2.0``, every other cell at ``A = 0``.  All
        cells start at ``I = mu``.

    Returns
    -------
    numpy.ndarray
        For each cell, the first sample time at which ``A >= 1`` (``tmax`` if
        that never happens within the run), shifted so the earliest cell is 0.
    """
    n = cpos.shape[0]
    r, rho, mu, delta = k[0], k[1], k[2], k[3]
    theta = 1
    tmax = 500
    # NOTE(review): tmax samples over [0, tmax] gives a step of tmax/(tmax-1),
    # not 1; use tmax + 1 samples if unit spacing was intended.
    tspan = np.linspace(0, tmax, tmax)

    A0 = np.zeros(n)
    A0[np.where(cpos == 0)[0]] = 2.0
    I0 = mu + np.zeros(n)
    X0 = np.hstack((A0, I0))
    Y = odeint(dAdI, X0, tspan, args=(r, mu, rho, theta, delta))

    # Columns :n of Y are the A values.  argmax on the boolean mask picks the
    # first sample at which each cell reaches 1; cells that never do get tmax.
    reached = Y[:, :n] >= 1
    times = np.where(reached.any(axis=0), tspan[reached.argmax(axis=0)], tmax)
    return times - np.min(times)

def get_timesr(k, cpos):
    """Simulate the full (A, I) model from a random start; return crossing times.

    This matches ``get_times`` except for the initial A.  Cells with
    ``cpos == 0`` start at ``A = 2.0``.  Every other cell starts at a value
    drawn uniformly from [0, 0.5) with numpy's global RNG
    (``np.random.rand``), so results vary between calls.

    Parameters
    ----------
    k : sequence
        ``(r, rho, mu, delta)``; any extra entries are ignored.
    cpos : numpy.ndarray, 1-D
        Cell positions; only its length and which entries equal 0 are used.

    Returns
    -------
    numpy.ndarray
        For each cell, the first sample time at which ``A >= 1`` (``tmax`` if
        never reached), shifted so the earliest cell is 0.
    """
    n = cpos.shape[0]
    A0 = 0.5 * np.random.rand(n)
    A0[np.where(cpos == 0)[0]] = 2.0
    r, rho, mu, delta = k[0], k[1], k[2], k[3]
    theta = 1
    tmax = 500
    # NOTE(review): tmax samples over [0, tmax] gives a step of tmax/(tmax-1),
    # not 1; use tmax + 1 samples if unit spacing was intended.
    tspan = np.linspace(0, tmax, tmax)

    I0 = mu + np.zeros(n)
    X0 = np.hstack((A0, I0))
    Y = odeint(dAdI, X0, tspan, args=(r, mu, rho, theta, delta))

    # Columns :n of Y are the A values.  argmax on the boolean mask picks the
    # first sample at which each cell reaches 1; cells that never do get tmax.
    reached = Y[:, :n] >= 1
    times = np.where(reached.any(axis=0), tspan[reached.argmax(axis=0)], tmax)
    return times - np.min(times)

def get_times2(k, cpos):
    """Simulate the reduced A-only model (``dA``) and return crossing times.

    Parameters
    ----------
    k : sequence
        ``(r, rho, mu)``; any extra entries are ignored.
    cpos : numpy.ndarray, 1-D
        Cell positions.  Cells with ``cpos == 0`` start at ``A = 2.0``, every
        other cell at ``A = 0``.

    Returns
    -------
    numpy.ndarray
        For each cell, the first sample time in ``[0, 200]`` at which
        ``A >= 1`` (``tmax`` = 200 if never reached), shifted so the earliest
        cell is 0.
    """
    n = cpos.shape[0]
    r, rho, mu = k[0], k[1], k[2]
    theta = 1
    tmax = 200
    # NOTE(review): 200 samples over [0, 200] gives a step of 200/199, not 1.
    tspan = np.linspace(0, tmax, 200)

    A0 = np.zeros(n)
    A0[np.where(cpos == 0)[0]] = 2.0
    Y = odeint(dA, A0.copy(), tspan, args=(r, mu, rho, theta))

    # argmax on the boolean mask picks the first sample at which each cell
    # reaches 1; cells that never do get tmax.
    reached = Y >= 1
    times = np.where(reached.any(axis=0), tspan[reached.argmax(axis=0)], tmax)
    return times - np.min(times)
    
def xhi2_times2(k, cpos, ctimes):
    """Sum of squared errors between ``get_times2(k, cpos)`` and ``ctimes``.

    The simulated times are reshaped to ``ctimes.shape`` before comparison.
    ``k`` and the score are printed on each call.
    """
    simulated = get_times2(k, cpos).reshape(ctimes.shape)
    residual = simulated - ctimes
    xhi2 = np.sum(residual ** 2)
    print(k, xhi2)
    return xhi2

def compare(k, cpos, ctimes):
    """Plot observed ``ctimes`` and ``get_times(k, cpos)`` against ``cpos``, then show."""
    simulated = get_times(k, cpos)
    for series in (ctimes, simulated):
        plt.plot(cpos, series)
    plt.show()


def comparer(k, cpos, ctimes, ix, name):
    """Plot observed times against the mean of 10 ``get_timesr`` runs and save a PDF.

    Both curves are drawn against ``cpos``.  The figure is written to
    ``f'{name}_{ix}.pdf'`` and then closed.  Simulated points whose mean time
    is >= 200 are left out of the plot (NOTE(review): presumably to hide
    cells that did not cross within the run — confirm).

    Returns
    -------
    numpy.ndarray
        The per-cell mean simulated times, unfiltered.
    """
    runs = [get_timesr(k, cpos) for _ in range(10)]
    times = np.mean(np.stack(runs), 0)
    shown = times < 200
    plt.plot(cpos, ctimes)
    plt.plot(cpos[shown], times[shown])
    plt.savefig(f'{name}_{ix}.pdf')
    plt.close()
    return times

def compare2(k, cpos, ctimes):
    """Plot observed ``ctimes`` and ``get_times2(k, cpos)`` against ``cpos``, then show."""
    simulated = get_times2(k, cpos)
    for series in (ctimes, simulated):
        plt.plot(cpos, series)
    plt.show()


# r = 5
# theta = 1 
# rho = 150
# delta = 1
# mu = 1
# tmax = 1000
# tspan = np.linspace(0, tmax, 200)
# n =  10 
# A0 = np.zeros(n)
# A0[ n//2 ] = 0.5
# I0 = mu + np.zeros(n)
# X0 = np.hstack((A0, I0))
# print(X0.shape)
# Y = odeint(dAdI,X0, tspan,args=(r, mu, rho, theta, delta))


# times = np.zeros(n)
# for x in range(n):
#     wx = np.where(Y[:, x]>=1)[0]
#     if len(wx>0):
#         times[x] = tspan[wx[0]]
#     else: times[x] = tmax

# plt.plot(range(n), times - np.min(times))
# plt.show()

import pandas as pd  


def get_data(file):
    """Read a tab-separated table and group its CTIME values by LINE and CPOS.

    The file must have columns ``ID``, ``LINE``, ``SIDE``, ``CPOS`` and
    ``CTIME``.  Only rows with ``SIDE`` equal to ``'D'`` or ``'G'`` are kept.
    Within each line, values are appended side ``'D'`` first, then ``'G'``;
    within a side they are appended by ``ID`` in order of first appearance
    in the file, then by row order.

    Returns
    -------
    dict
        ``{line: {cpos: [ctime, ...]}}``.
    """
    table = pd.read_csv(file, sep='\t')
    individuals = table['ID'].unique()
    grouped = {}
    for line in table['LINE'].unique():
        line_rows = table[table['LINE'] == line]
        by_pos = grouped.setdefault(line, {})
        for side in ('D', 'G'):
            side_rows = line_rows[line_rows['SIDE'] == side]
            for ind in individuals:
                ind_rows = side_rows[side_rows['ID'] == ind]
                for pos, ctime in zip(ind_rows['CPOS'], np.array(ind_rows['CTIME'])):
                    by_pos.setdefault(pos, []).append(ctime)
    return grouped

def get_mean(data_mean1):
    """Split a ``{position: time}`` mapping into sorted position and time arrays.

    Returns
    -------
    tuple
        ``(data_mean1, pos, times)``.  ``data_mean1`` is the input object
        itself, ``pos`` holds its sorted keys, and ``times`` holds the
        matching values in the same order.
    """
    ordered = sorted(data_mean1)
    pos = np.array(ordered)
    times = np.array([data_mean1[key] for key in ordered])
    return data_mean1, pos, times

def get_fit(data_set, name, start=None):
    """Fit ``(rho, mu)`` separately for each ``{position: time}`` dict in ``data_set``.

    Each dict is minimised with ``fmin`` on ``xhi2_timesrl``, starting from
    ``start`` or, when it is None, from ``[36.32, 0.502]``.  The fitted
    parameters are then passed to ``comparer`` with ``r = 2.1`` and
    ``delta = 1``, which saves a plot under the prefix ``name + '2'``.

    Returns
    -------
    dict
        Keyed by 1-based line index; each value is
        ``{'k': fitted [rho, mu], 'values': mean simulated times}``.
    """
    fits = {}
    for label, data_mean in enumerate(data_set, start=1):
        print('line {}'.format(label))
        ordered = sorted(data_mean.keys())
        pos = np.array(ordered)
        times = np.array([data_mean[u] for u in ordered])
        k0 = start if start is not None else [36.32, 0.502]
        k_opt = fmin(xhi2_timesrl, k0, args=(pos, times))
        print(label, k_opt)
        fits[label] = {
            'k': k_opt,
            'values': comparer([2.1, k_opt[0], k_opt[1], 1], pos, times, label, name + '2'),
        }
    return fits

# Dataset label -> tab-separated input file read by get_data.
names = {'sca_delta':'sca_delta_L1.tsv', 'scabp':'sca_L1.tsv', 'delta':'delta_L1.tsv'}
# Dataset label -> initial [rho, mu] guess passed as `start` to get_fit
# (xhi2_timesrl expands p to k = [2.1, p[0], p[1], 1.0], i.e. r and delta fixed).
# NOTE(review): the trailing commented numbers look like alternative earlier
# estimates for the same [rho, mu] pair — confirm.
starts = {'sca_delta':[41.46607812500001, 0.35259269531249999], 
            'delta':[21.64515625916353, 0.53692343740854632], #21.462409068776488, 0.536649613032369
            'scabp': [27.9938681,  0.524669], #27.589092873662338, 0.5253708585222865
            }


# Fit every dataset listed in `names`.  Each line's per-position times are
# averaged, then fitted; the results are pickled to fits3.pk.
fits = {}
for name, file in names.items():
    print(name, file)
    data = get_data(file)
    data_set = [
        {u: np.mean(np.array(data[k][u])) for u in sorted(data[k])}
        for k in data
    ]
    # datasets without an entry in `starts` fall back to get_fit's default guess
    fits[name] = get_fit(data_set, name, starts.get(name))

import pickle
with open('fits3.pk', 'wb') as iof:
    pickle.dump(fits, iof)


# sca delta 41.39514317  0.35844966 41.27170588  0.35810196 41.21084096  0.35035748
# scabp 38.00410751  0.43251426 24.55118666  0.53573437 29.6086487   0.48629203  27.30685966  0.5298434
# delta 22.15712256  0.53601104  23.48924493  0.5270235  22.72912081  0.52719067 39.53096529  0.44475856